{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c4600e50",
   "metadata": {
    "origin_pos": 0
   },
   "source": [
    "# Approximate Training\n",
    ":label:`sec_approx_train`\n",
    "\n",
    "Recall our discussions in :numref:`sec_word2vec`.\n",
    "The main idea of the skip-gram model is\n",
    "using softmax operations to calculate\n",
    "the conditional probability of\n",
    "generating a context word $w_o$\n",
    "based on the given center word $w_c$\n",
    "in :eqref:`eq_skip-gram-softmax`,\n",
    "whose corresponding logarithmic loss is given by\n",
    "the opposite of :eqref:`eq_skip-gram-log`.\n",
    "\n",
    "\n",
    "\n",
    "Due to the nature of the softmax operation,\n",
    "since a context word may be anyone in the\n",
    "dictionary $\\mathcal{V}$,\n",
    "the opposite of :eqref:`eq_skip-gram-log`\n",
    "contains the summation\n",
    "of items as many as the entire size of the vocabulary.\n",
    "Consequently,\n",
    "the gradient calculation\n",
    "for the skip-gram model\n",
    "in :eqref:`eq_skip-gram-grad`\n",
    "and that\n",
    "for the continuous bag-of-words model\n",
    "in :eqref:`eq_cbow-gradient`\n",
    "both contain\n",
    "the summation.\n",
    "Unfortunately,\n",
    "the computational cost\n",
    "for such gradients\n",
    "that sum over\n",
    "a large dictionary\n",
    "(often with\n",
    "hundreds of thousands or millions of words)\n",
    "is huge!\n",
    "\n",
    "In order to reduce the aforementioned computational complexity, this section will introduce two approximate training methods:\n",
    "*negative sampling* and *hierarchical softmax*.\n",
    "Due to the similarity\n",
    "between the skip-gram model and\n",
    "the continuous bag of words model,\n",
    "we will just take the skip-gram model as an example\n",
    "to describe these two approximate training methods.\n",
    "\n",
    "## Negative Sampling\n",
    ":label:`subsec_negative-sampling`\n",
    "\n",
    "\n",
    "Negative sampling modifies the original objective function.\n",
    "Given the context window of a center word $w_c$,\n",
    "the fact that any (context) word $w_o$\n",
    "comes from this context window\n",
    "is considered as an event with the probability\n",
    "modeled by\n",
    "\n",
    "\n",
    "$$P(D=1\\mid w_c, w_o) = \\sigma(\\mathbf{u}_o^\\top \\mathbf{v}_c),$$\n",
    "\n",
    "where $\\sigma$ uses the definition of the sigmoid activation function:\n",
    "\n",
    "$$\\sigma(x) = \\frac{1}{1+\\exp(-x)}.$$\n",
    ":eqlabel:`eq_sigma-f`\n",
    "\n",
    "Let's begin by\n",
    "maximizing the joint probability of\n",
    "all such events in text sequences\n",
    "to train word embeddings.\n",
    "Specifically,\n",
    "given a text sequence of length $T$,\n",
    "denote by $w^{(t)}$ the word at time step $t$\n",
    "and let the context window size be $m$,\n",
    "consider maximizing the joint probability\n",
    "\n",
    "\n",
    "$$ \\prod_{t=1}^{T} \\prod_{-m \\leq j \\leq m,\\ j \\neq 0} P(D=1\\mid w^{(t)}, w^{(t+j)}).$$\n",
    ":eqlabel:`eq-negative-sample-pos`\n",
    "\n",
    "\n",
    "However,\n",
    ":eqref:`eq-negative-sample-pos`\n",
    "only considers those events\n",
    "that involve positive examples.\n",
    "As a result,\n",
    "the joint probability in\n",
    ":eqref:`eq-negative-sample-pos`\n",
    "is maximized to 1\n",
    "only if all the word vectors are equal to infinity.\n",
    "Of course,\n",
    "such results are meaningless.\n",
    "To make the objective function\n",
    "more meaningful,\n",
    "*negative sampling*\n",
    "adds negative examples sampled\n",
    "from a predefined distribution.\n",
    "\n",
    "Denote by $S$\n",
    "the event that\n",
    "a context word $w_o$ comes from\n",
    "the context window of a center word $w_c$.\n",
    "For this event involving $w_o$,\n",
    "from a predefined distribution $P(w)$\n",
    "sample $K$ *noise words*\n",
    "that are not from this context window.\n",
    "Denote by $N_k$\n",
    "the event that\n",
    "a noise word $w_k$ ($k=1, \\ldots, K$)\n",
    "does not come from\n",
    "the context window of $w_c$.\n",
    "Assume that\n",
    "these events involving\n",
    "both the positive example and negative examples\n",
    "$S, N_1, \\ldots, N_K$ are mutually independent.\n",
    "Negative sampling\n",
    "rewrites the joint probability (involving only positive examples)\n",
    "in :eqref:`eq-negative-sample-pos`\n",
    "as\n",
    "\n",
    "$$ \\prod_{t=1}^{T} \\prod_{-m \\leq j \\leq m,\\ j \\neq 0} P(w^{(t+j)} \\mid w^{(t)}),$$\n",
    "\n",
    "where the conditional probability is approximated through\n",
    "events $S, N_1, \\ldots, N_K$:\n",
    "\n",
    "$$ P(w^{(t+j)} \\mid w^{(t)}) =P(D=1\\mid w^{(t)}, w^{(t+j)})\\prod_{k=1,\\ w_k \\sim P(w)}^K P(D=0\\mid w^{(t)}, w_k).$$\n",
    ":eqlabel:`eq-negative-sample-conditional-prob`\n",
    "\n",
    "Denote by\n",
    "$i_t$ and $h_k$\n",
    "the indices of\n",
    "a word $w^{(t)}$ at time step $t$\n",
    "of a text sequence\n",
    "and a noise word $w_k$,\n",
    "respectively.\n",
    "The logarithmic loss with respect to the conditional probabilities in :eqref:`eq-negative-sample-conditional-prob` is\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "-\\log P(w^{(t+j)} \\mid w^{(t)})\n",
    "=& -\\log P(D=1\\mid w^{(t)}, w^{(t+j)}) - \\sum_{k=1,\\ w_k \\sim P(w)}^K \\log P(D=0\\mid w^{(t)}, w_k)\\\\\n",
    "=&-  \\log\\, \\sigma\\left(\\mathbf{u}_{i_{t+j}}^\\top \\mathbf{v}_{i_t}\\right) - \\sum_{k=1,\\ w_k \\sim P(w)}^K \\log\\left(1-\\sigma\\left(\\mathbf{u}_{h_k}^\\top \\mathbf{v}_{i_t}\\right)\\right)\\\\\n",
    "=&-  \\log\\, \\sigma\\left(\\mathbf{u}_{i_{t+j}}^\\top \\mathbf{v}_{i_t}\\right) - \\sum_{k=1,\\ w_k \\sim P(w)}^K \\log\\sigma\\left(-\\mathbf{u}_{h_k}^\\top \\mathbf{v}_{i_t}\\right).\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "\n",
    "We can see that\n",
    "now the computational cost for gradients\n",
    "at each training step\n",
    "has nothing to do with the dictionary size,\n",
    "but linearly depends on $K$.\n",
    "When setting the hyperparameter $K$\n",
    "to a smaller value,\n",
    "the computational cost for gradients\n",
    "at each training step with negative sampling\n",
    "is smaller.\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "## Hierarchical Softmax\n",
    "\n",
    "As an alternative approximate training method,\n",
    "*hierarchical softmax*\n",
    "uses the binary tree,\n",
    "a data structure\n",
    "illustrated in :numref:`fig_hi_softmax`,\n",
    "where each leaf node\n",
    "of the tree represents\n",
    "a word in dictionary $\\mathcal{V}$.\n",
    "\n",
    "![Hierarchical softmax for approximate training, where each leaf node of the tree represents a word in the dictionary.](../img/hi-softmax.svg)\n",
    ":label:`fig_hi_softmax`\n",
    "\n",
    "Denote by $L(w)$\n",
    "the number of nodes (including both ends)\n",
    "on the path\n",
    "from the root node to the leaf node representing word $w$\n",
    "in the binary tree.\n",
    "Let $n(w,j)$ be the $j^\\textrm{th}$ node on this path,\n",
    "with its context word vector being\n",
    "$\\mathbf{u}_{n(w, j)}$.\n",
    "For example,\n",
    "$L(w_3) = 4$ in  :numref:`fig_hi_softmax`.\n",
    "Hierarchical softmax approximates the conditional probability in :eqref:`eq_skip-gram-softmax` as\n",
    "\n",
    "\n",
    "$$P(w_o \\mid w_c) = \\prod_{j=1}^{L(w_o)-1} \\sigma\\left( [\\![  n(w_o, j+1) = \\textrm{leftChild}(n(w_o, j)) ]\\!] \\cdot \\mathbf{u}_{n(w_o, j)}^\\top \\mathbf{v}_c\\right),$$\n",
    "\n",
    "where function $\\sigma$\n",
    "is defined in :eqref:`eq_sigma-f`,\n",
    "and $\\textrm{leftChild}(n)$ is the left child node of node $n$: if $x$ is true, $[\\![x]\\!] = 1$; otherwise $[\\![x]\\!] = -1$.\n",
    "\n",
    "To illustrate,\n",
    "let's calculate\n",
    "the conditional probability\n",
    "of generating word $w_3$\n",
    "given word $w_c$ in :numref:`fig_hi_softmax`.\n",
    "This requires dot products\n",
    "between the word vector\n",
    "$\\mathbf{v}_c$ of $w_c$\n",
    "and\n",
    "non-leaf node vectors\n",
    "on the path (the path in bold in :numref:`fig_hi_softmax`) from the root to $w_3$,\n",
    "which is traversed left, right, then left:\n",
    "\n",
    "\n",
    "$$P(w_3 \\mid w_c) = \\sigma(\\mathbf{u}_{n(w_3, 1)}^\\top \\mathbf{v}_c) \\cdot \\sigma(-\\mathbf{u}_{n(w_3, 2)}^\\top \\mathbf{v}_c) \\cdot \\sigma(\\mathbf{u}_{n(w_3, 3)}^\\top \\mathbf{v}_c).$$\n",
    "\n",
    "Since $\\sigma(x)+\\sigma(-x) = 1$,\n",
    "it holds that\n",
    "the conditional probabilities of\n",
    "generating all the words in\n",
    "dictionary $\\mathcal{V}$\n",
    "based on any word $w_c$\n",
    "sum up to one:\n",
    "\n",
    "$$\\sum_{w \\in \\mathcal{V}} P(w \\mid w_c) = 1.$$\n",
    ":eqlabel:`eq_hi-softmax-sum-one`\n",
    "\n",
    "Fortunately, since $L(w_o)-1$ is on the order of $\\mathcal{O}(\\textrm{log}_2|\\mathcal{V}|)$ due to the binary tree structure,\n",
    "when the dictionary size $\\mathcal{V}$ is huge,\n",
    "the computational cost for  each training step using hierarchical softmax\n",
    "is significantly reduced compared with that\n",
    "without approximate training.\n",
    "\n",
    "## Summary\n",
    "\n",
    "* Negative sampling constructs the loss function by considering mutually independent events that involve both positive and negative examples. The computational cost for training is linearly dependent on the number of noise words at each step.\n",
    "* Hierarchical softmax constructs the loss function using  the path from the root node to the leaf node in the binary tree. The computational cost for training is dependent on the logarithm of the dictionary size at each step.\n",
    "\n",
    "## Exercises\n",
    "\n",
    "1. How can we sample noise words in negative sampling?\n",
    "1. Verify that :eqref:`eq_hi-softmax-sum-one` holds.\n",
    "1. How to train the continuous bag of words model using negative sampling and hierarchical softmax, respectively?\n",
    "\n",
    "[Discussions](https://discuss.d2l.ai/t/382)\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  },
  "required_libs": []
 },
 "nbformat": 4,
 "nbformat_minor": 5
}